Neuron Recycling

Authors and affiliations

  • Jakub Krajewski (IDEAS NCBR, University of Warsaw)
  • Maciej Pióro (IDEAS NCBR, University of Warsaw)
  • Sebastian Jaszczur (IDEAS NCBR, University of Warsaw)
  • Marek Cygan (University of Warsaw)

Published

July 7, 2023

Sparse neural networks have garnered attention due to their theoretical promise of lowered computational demands and memory savings. However, to date, the theoretical gains have largely failed to materialize due to the lack of hardware support for this kind of model. In this work, we explore the idea of neuron recycling, which is inspired by pruning, a method often employed to induce sparsity in neural networks. We also present lessons we have learned along the way.

Introduction

Pruning is a well-established technique used to sparsify neural networks. It relies on the fact that a large part of a trained neural network can typically be masked without impacting the accuracy of the network, albeit often requiring additional fine-tuning to regain some lost performance. Despite multiple works proposing various neuron-selection criteria for pruning, magnitude-based pruning remains a viable option.

The Lottery Ticket Hypothesis (LTH) is a major finding on the way to explaining how initialization impacts neural networks. The main point of the LTH is that, through iterative pruning, performant subnetworks that depend on the initialization can be found in neural networks. Those well-initialized network fragments are the namesake of the LTH (the “lottery tickets”). Although some notions of the original LTH paper have been challenged, it has remained the subject of active research and a motivation for our work.

By combining the two ideas (pruning and the LTH), we arrive at a new potential technique for raising neural network performance. If we are able to remove parts of the network without hurting performance (pruning), and the fitness of a part of a network is determined at initialization, perhaps we could re-initialize the unnecessary network parts (i.e., draw more “lottery tickets”), leading to a better-performing network.

Preliminaries

Before we move to the presentation of our experiments and findings, let’s first discuss the training setup, define key terminology, and go over the basics.

Model and training setup

In our project, we focus on the Transformer, since it is a major architecture across different domains. As the specific model, we use the encoder-only BERT. Taking into consideration the available computational resources and the expected iteration time (we wanted to try as many options as possible), we opted for the BERT Medium configuration (with \(d_\text{model}=512\) and \(8\) attention heads). We focus on the feed-forward layer because it is the most computationally demanding part of commonly used Transformer models and, in large models, it contains the majority of the parameters. At the same time, the amount of research focusing on the attention mechanism is overwhelming, suggesting that the feed-forward layer is a relatively unexplored area.

We trained the model for \(80{,}000\) steps (around the compute-optimal point) with Adam, using a batch size of \(256\) and a learning rate of \(0.001\).

It has been shown that the initialization of the weights in neural networks is of utmost importance to the convergence of the network. In all our experiments, we initialized the linear layers using the following distribution: as per .

In the rest of this post, we will often use the terms neuron and magnitude. Below are the definitions we employ.

  • Neuron
    In the Transformer, the feed-forward layer consists of two linear layers with a nonlinearity in between. The first layer maps the input vector from dimension \(d_\text{model}\) to \(d_\text{ff}\), and the second one maps it from \(d_\text{ff}\) back to \(d_\text{model}\). Typically, \(d_\text{ff}\) is four times greater than \(d_\text{model}\). By a neuron, we mean all weights interacting with a particular coordinate in the hidden space \(\mathbb{R}^{d_\text{ff}}\). In the torch implementation, a neuron’s weights are the parameters in a given row of the first feed-forward matrix and in the corresponding column of the second one.

  • Magnitude
    As the magnitude of a weight, we use its absolute value. As the magnitude of a neuron, we use the value of the expression \(M=\sqrt{\sum{w_{in}^2} \sum{w_{out}^2}}\), where \(w_{in}\) and \(w_{out}\) range over the neuron’s weights in the first and second matrix, respectively.
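The neuron magnitude defined above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the code used in the experiments: the function name `neuron_magnitudes` is made up for this example, the weight shapes follow the torch convention described above (first matrix of shape `(d_ff, d_model)`, second of shape `(d_model, d_ff)`), and biases are ignored, which is an assumption.

```python
import numpy as np

def neuron_magnitudes(w_in: np.ndarray, w_out: np.ndarray) -> np.ndarray:
    """Return M = sqrt(sum(w_in^2) * sum(w_out^2)) for every hidden neuron.

    w_in  has shape (d_ff, d_model): row n holds the input weights of neuron n.
    w_out has shape (d_model, d_ff): column n holds the output weights of neuron n.
    Biases are not included (an assumption of this sketch).
    """
    in_sq = np.sum(w_in ** 2, axis=1)    # one value per neuron, shape (d_ff,)
    out_sq = np.sum(w_out ** 2, axis=0)  # one value per neuron, shape (d_ff,)
    return np.sqrt(in_sq * out_sq)

# Tiny example with d_model = 2 and d_ff = 3.
w_in = np.array([[3.0, 4.0], [0.0, 0.0], [1.0, 0.0]])
w_out = np.array([[1.0, 5.0, 0.0], [0.0, 5.0, 2.0]])
magnitudes = neuron_magnitudes(w_in, w_out)  # array([5., 0., 2.])
```

Note that the second neuron has magnitude zero even though its output weights are large: the product form means a neuron counts as small whenever either its input or its output side is small.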

Pruning

Pruning is a technique used to induce sparsity and decrease the parameter count in a neural network. In simple terms, it means deleting the least important neurons (structured pruning) or weights (unstructured pruning). A typical implementation realizes this by either multiplying the output of the deleted neurons by 0 or setting the weights of the neuron to 0. A widely used proxy for the importance of a neuron or weight is its magnitude. You may find it counterintuitive, but we can even remove the whole FF layer and the model will still work. This is because the model can learn to represent the same transformation using attention.
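To make the two implementation options concrete, here is a small sketch of structured magnitude pruning with a mask. The helper names `prune_mask` and `feed_forward` are hypothetical, the ReLU nonlinearity is an assumption (the text only says "a nonlinearity"), and biases are left out.

```python
import numpy as np

def prune_mask(magnitudes: np.ndarray, fraction: float) -> np.ndarray:
    """Return a 0/1 mask that zeroes the `fraction` of neurons with the lowest magnitude."""
    n_pruned = int(round(fraction * magnitudes.size))
    mask = np.ones(magnitudes.size)
    pruned = np.argsort(magnitudes, kind="stable")[:n_pruned]
    mask[pruned] = 0.0
    return mask

def feed_forward(x, w_in, w_out, mask):
    """Bias-free feed-forward layer; hidden activations are multiplied by the mask."""
    hidden = np.maximum(x @ w_in.T, 0.0)  # ReLU is an assumption, shape (batch, d_ff)
    return (hidden * mask) @ w_out.T      # shape (batch, d_model)

# Prune half of four neurons: the two with the lowest magnitude are masked.
mask = prune_mask(np.array([5.0, 0.1, 2.0, 0.3]), fraction=0.5)  # array([1., 0., 1., 0.])
```

In this bias-free sketch, multiplying a neuron's output by 0 gives the same layer output as setting that neuron's row in `w_in` and column in `w_out` to 0, so the two options from the paragraph above are interchangeable here.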

Below we present a plot with loss curves of the model whose FF layer is gradually pruned, starting at step \(10{,}000\), such that the layer is completely masked by the end of training. For comparison, we also include a regular model and one without the feed-forward layer.
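A gradual pruning schedule of this kind could look as follows. The start and end steps come from the text (pruning starts at step 10,000 and training lasts 80,000 steps); the linear shape of the schedule and the function name `pruned_fraction` are assumptions made for illustration.

```python
def pruned_fraction(step: int, start_step: int = 10_000, end_step: int = 80_000) -> float:
    """Fraction of feed-forward neurons masked at `step` under a linear schedule.

    The linear ramp is an assumption; the text only states that pruning starts at
    `start_step` and that the layer is fully masked by the end of training.
    """
    progress = (step - start_step) / (end_step - start_step)
    return min(max(progress, 0.0), 1.0)

# Nothing is pruned before step 10,000; everything is masked at step 80,000.
fractions = [pruned_fraction(s) for s in (0, 10_000, 45_000, 80_000)]  # [0.0, 0.0, 0.5, 1.0]
```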

Interestingly, the effect of pruning is not visible for a significant fraction of the training time. It is also worth noting that, in the end, the model without the FF layer performs slightly better than the pruned one. This is because, in the model without the FF layer, attention was trained to compensate from the very beginning of training.

The goal

The end-goal of the project was to create a method that would allow us to make better use of the parameters in the feed-forward layer. In this context, a natural question arises - against what baseline should our results be compared? To answer this question, we trained the model with differing dimensionalities of the feed-forward layer. The results are presented below. The true BERT Medium configuration has \(d_\text{ff}=2048\), and, as expected, the model’s performance drops when the \(d_\text{ff}\) is decreased and increases when the \(d_\text{ff}\) is increased.

Understanding neuron magnitudes

One of the key inspirations for our work was structured pruning, where neuron magnitude is often chosen as the measure of significance. We were interested in how this metric evolves during the training process. At first, we expected a histogram of neuron magnitudes to resemble a normal distribution. However, our experiments showed something different. The following graph shows the evolution of neuron magnitudes throughout the training process.


In the early stages of training, the neurons clearly split into two distinct groups, one with much lower magnitudes than the other. This finding opens up a number of questions. One can speculate that the neurons in the low-magnitude group are not very important, in which case eliminating them is an option. However, it is also possible that these small neurons, though seldom activated, play a critical role in specific tasks.

Notably, this split does not occur in the last layers of the network. As an example, below you can examine the magnitudes of neurons in the 8th FF layer of the network.


After examining these experiments, we tried to understand why, in the early layers, we observed two distinct groups of neurons, categorized by their magnitudes. One possible explanation is that certain parts of the network receive a smaller signal and are slower to improve in training. We designed an experiment to check this. We periodically froze all parts of the network except for the feed-forward component and continued to train it for several batches of data. We hypothesized that in this scenario, weaker neurons might catch up, resulting in a more even distribution. We called this procedure overtraining the feed-forward layer. It is important to note that this approach is impractical and computationally heavy, but we wanted to use it for the purpose of illustration. The results are depicted in the following plot.


We can see that the group of “weaker” neurons has moved to the right after additional training of the FF part. However, the neurons still form two distinct groups: overtraining the whole layer is not enough for the weaker ones to catch up. In the next experiment, we examined the scenarios of overtraining only small-magnitude neurons, only large-magnitude neurons, and random subsets. How does this affect performance? The results are depicted in the following plot.

Overtraining only the smallest neurons yields better results than reinforcing the high-magnitude ones. Notably, overtraining the small ones gives gains in performance similar to working on the entire layer! In contrast, amplifying the largest ones gives results similar to no overtraining at all. This provides a compelling argument in favor of our technique.

So far, we have performed a series of experiments on relatively small architectures. We were curious to see how our observations would translate to well-established, large-scale foundation models like BERT Large or T5.

Upon examining the magnitudes in these foundation models, noticeable differences emerge. The magnitudes in T5 seem similar to those in our smaller models, while BERT and GPT-2 display more favorable distributions. What could account for these variations? We discovered that the use of weight decay plays a significant role. This simple but widely used technique has a considerable impact on the distribution phenomenon we’ve been investigating.

These findings support the idea of exploring neuron recycling more thoroughly and offer a solid foundation for further experiments. In subsequent sections, we will delve into the results of these investigations and share our insights.

Recycling

The central part of our work was a method we called neuron recycling. The whole process boils down to three phases, repeated periodically: training, selection and reinitialization.

  • In the training phase, the model is trained to predict masked tokens (masked language modelling).
  • In the selection phase, the least important neurons are determined, where the baseline criterion is neuron magnitude.
  • In the reinitialization phase, new weights are assigned to neurons.

Although this procedure is conceptually simple, it allows for many degrees of freedom. Here are some choices that can be made:

  • The number of training steps before consecutive selection / reinitialization phases
  • The percentage of recycled neurons
  • Selection / reinitialization strategies
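The three phases and the degrees of freedom listed above can be put together in a short sketch. Everything here is illustrative: the function names, the `train_step` callback standing in for a real optimizer step, and the zero-mean normal initialization with standard deviation `init_std` are assumptions, not the setup used in the experiments.

```python
import numpy as np

def neuron_magnitudes(w_in, w_out):
    # w_in: (d_ff, d_model), w_out: (d_model, d_ff)
    return np.sqrt(np.sum(w_in ** 2, axis=1) * np.sum(w_out ** 2, axis=0))

def select_lowest(magnitudes, fraction):
    """Selection phase: indices of the `fraction` of neurons with the lowest magnitude."""
    n_selected = int(round(fraction * magnitudes.size))
    return np.argsort(magnitudes, kind="stable")[:n_selected]

def reinitialize(w_in, w_out, idx, init_std, rng):
    """Reinitialization phase: overwrite the selected neurons in place (assumed N(0, init_std^2))."""
    w_in[idx, :] = rng.normal(0.0, init_std, size=(idx.size, w_in.shape[1]))
    w_out[:, idx] = rng.normal(0.0, init_std, size=(w_out.shape[0], idx.size))

def train_with_recycling(w_in, w_out, train_step, n_steps, recycle_every, fraction, init_std, rng):
    """Alternate training with periodic selection and reinitialization."""
    recycled = []
    for step in range(1, n_steps + 1):
        train_step(w_in, w_out)  # training phase (hypothetical callback)
        if step % recycle_every == 0:
            idx = select_lowest(neuron_magnitudes(w_in, w_out), fraction)
            reinitialize(w_in, w_out, idx, init_std, rng)
            recycled.append(idx)
    return recycled

# Toy run: d_model = 4, d_ff = 8, a no-op "training" step, recycling 25% every 3 steps.
rng = np.random.default_rng(0)
w_in = rng.normal(0.0, 0.02, size=(8, 4))
w_out = rng.normal(0.0, 0.02, size=(4, 8))
recycled = train_with_recycling(w_in, w_out, lambda a, b: None, n_steps=6,
                                recycle_every=3, fraction=0.25, init_std=0.02, rng=rng)
```

The arguments `recycle_every`, `fraction`, and the pair `select_lowest` / `reinitialize` correspond one-to-one to the three choices in the list above, so alternative strategies can be swapped in without touching the loop.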

After examining the pruning literature [citations], we found that the simple magnitude-based approach works best in most cases. It is also easy to implement and computationally efficient. This approach is also grounded in our experiments. Below we present the training curves for the model pruned gradually using different criteria: high-magnitude, low-magnitude, and random neurons.

As you can see, removing low-magnitude neurons hurts the model the least, and removing high-magnitude ones causes the largest increase in loss. This is a good argument that this criterion correlates well with neuron significance.

Baseline recycling

The most straightforward reinitialization scheme is to sample the weights of the reinitialized neurons from the initial distribution. After examining the performance of this solution, we could not see any difference between recycling and vanilla training. As a sanity check, we examined the histogram of the number of times each neuron was recycled, and discovered that the same small subset of neurons was being reinitialized over and over during training. As we have seen, the average magnitude of neurons grows throughout training. Therefore, the newly sampled neurons have, on average, lower magnitudes than non-recycled neurons and are unable to catch up with them before the next selection phase. Thus, the recycled neurons are caught in a vicious cycle in which they are always recycled before achieving high magnitude.
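The mechanism behind this vicious cycle can be reproduced in a deliberately crude toy model that tracks magnitudes only. It is not a simulation of the actual training: the assumption that every magnitude is multiplied by the same `growth` factor between selection phases, the reset value `init_magnitude`, and the function name are all made up for illustration.

```python
import numpy as np

def simulate_recycle_counts(magnitudes, n_rounds, n_recycled, growth, init_magnitude):
    """Count how often each neuron is recycled in a toy model of magnitude growth."""
    magnitudes = np.array(magnitudes, dtype=float)
    counts = np.zeros(magnitudes.size, dtype=int)
    for _ in range(n_rounds):
        magnitudes *= growth                    # toy "training": all magnitudes grow equally
        idx = np.argsort(magnitudes, kind="stable")[:n_recycled]
        magnitudes[idx] = init_magnitude        # reinitialize at the initial scale
        counts[idx] += 1
    return counts

counts = simulate_recycle_counts([1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
                                 n_rounds=5, n_recycled=2, growth=2.0, init_magnitude=1.0)
# counts is array([5, 5, 0, 0, 0, 0]): the same two neurons are recycled in every round.
```

Because a freshly reset neuron starts one growth period behind all the others, it is again among the smallest at the next selection phase, which matches the shape of the histogram described above.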

Immunity

Another thing we tried in order to make reinitialization work was immunity. The idea is to encourage diverse recycling by making each recycled neuron immune to further recycling for a predefined number of steps. We hypothesized that a reinitialized neuron needs some time to grow, but in our initial setting the neuron is recycled before enough growth happens, keeping the smaller neurons in a vicious cycle of repeated reinitialization. Unfortunately, in our experiments, the reinitialized neurons failed to grow even when given immunity. At the same time, the immunity of the small neurons led to the recycling of large neurons, hurting the model’s performance.
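One way to implement immunity is to store, for every neuron, the step until which it may not be selected. The sketch below is a hedged illustration: the bookkeeping array `immune_until`, the function name, and the parameter names are assumptions rather than the original implementation.

```python
import numpy as np

def select_with_immunity(magnitudes, fraction, immune_until, step, immunity_steps):
    """Select the lowest-magnitude neurons among those that are not currently immune.

    `immune_until[n]` is the first step at which neuron n may be recycled again;
    it is updated in place for the neurons selected here.
    """
    eligible = np.flatnonzero(immune_until <= step)
    n_selected = min(int(round(fraction * magnitudes.size)), eligible.size)
    order = np.argsort(magnitudes[eligible], kind="stable")
    idx = eligible[order[:n_selected]]
    immune_until[idx] = step + immunity_steps
    return idx

magnitudes = np.array([0.1, 0.2, 5.0, 6.0])
immune_until = np.zeros(4, dtype=int)
first = select_with_immunity(magnitudes, 0.5, immune_until, step=100, immunity_steps=500)
second = select_with_immunity(magnitudes, 0.5, immune_until, step=200, immunity_steps=500)
# first is array([0, 1]); second is array([2, 3]) because neurons 0 and 1 are still immune.
```

The second call shows the side effect described above: once the small neurons are protected, the selection falls on the large ones.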

Modifying reinitialization weights distribution

As we have pointed out before, the magnitude and weight distributions drift away from the initial distribution as training progresses. However, in our initial attempts, we sampled the new weights from the initial distribution. To fix this issue, we decided to try out another weight sampling technique. A feed-forward layer consists of two matrices. Each time a neuron is reinitialized, the corresponding row and column in the matrices are filled with new (reinitialized) weights. The new technique, when replacing the weights, used a normal distribution with mean and standard deviation equal to the mean and standard deviation of all the weights in the respective matrix. This approach, like immunity, eliminated the vicious cycle. However, it introduced a lot of noise, with an adverse effect on the model’s loss.
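A minimal sketch of this sampling scheme, assuming the same weight layout as in the torch convention (rows of the first matrix and columns of the second correspond to neurons); the function name `reinit_from_matrix_stats` is hypothetical.

```python
import numpy as np

def reinit_from_matrix_stats(w_in, w_out, idx, rng):
    """Refill the selected neurons with samples matching each matrix's current statistics.

    The mean and standard deviation are computed over all weights of the respective
    matrix (including the neurons about to be replaced) before anything is overwritten.
    """
    in_mean, in_std = w_in.mean(), w_in.std()
    out_mean, out_std = w_out.mean(), w_out.std()
    w_in[idx, :] = rng.normal(in_mean, in_std, size=(idx.size, w_in.shape[1]))
    w_out[:, idx] = rng.normal(out_mean, out_std, size=(w_out.shape[0], idx.size))

rng = np.random.default_rng(0)
w_in = rng.normal(0.0, 1.0, size=(6, 3))   # d_ff = 6, d_model = 3
w_out = rng.normal(0.0, 1.0, size=(3, 6))
reinit_from_matrix_stats(w_in, w_out, np.array([1, 4]), rng)
```

Since the statistics are taken from the trained matrices, the new neurons start at roughly the same scale as their neighbours, which is what breaks the cycle but also explains why the replacement acts as a large perturbation.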

Copying existing neurons

At this stage of our project, we became interested in the topic of growing neural networks. That problem shares a component with ours: new neurons or weights have to be added to an already trained network. In the case of Large Language Models, the topic is mentioned in [Gopher]. The authors report that, in their experiments, copying existing neurons was the best strategy. We tried it as well, but without positive results.

Smooth recycling

In an effort to explain our past problems, we pointed to the discrete nature of our technique, as opposed to the continuous landscape of training and optimizers. To change this, we modified our strategy to change the weights gradually, by linear interpolation between the old and the new values. This made the training loss dynamics smoother, but it did not beat the baseline.
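The interpolation itself is simple; the sketch below shows one plausible form. The blending window `n_blend_steps`, the function name, and the fact that the old weights are treated as fixed during blending are assumptions of this example, since the text only states that linear interpolation was used.

```python
import numpy as np

def interpolate(old, new, step_in_phase, n_blend_steps):
    """Linearly blend from `old` to `new` weights over `n_blend_steps` steps."""
    alpha = min(max(step_in_phase / n_blend_steps, 0.0), 1.0)
    return (1.0 - alpha) * old + alpha * new

old = np.array([0.0, 2.0])
new = np.array([4.0, -2.0])
halfway = interpolate(old, new, step_in_phase=5, n_blend_steps=10)  # array([2., 0.])
```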

Tangent - Midpoint Loss

While working towards the main goal of the project, we started investigating the distribution of neuron magnitudes during the training. We noticed that the magnitude distribution is quite uneven - a large percentage of neurons remains small, and the distribution is right-skewed. Since the goal of our project was to reduce the number of low-quality, i.e., small neurons, we came up with a pretty risky solution: Midpoint Loss. The idea was to introduce an additional loss term that would encourage neuron growth and “even-ness” of the magnitude distribution. The general equation for the midpoint loss is:

\[ Loss = \sum_{l = 1}^{L} \sum_{n = 1}^{d_\text{ff}} dist\left(M_{l,n},\ stop\_grad\left(avg(M_{l,*})\right)\right) \]

where:

  • \(M_{l,n}\) - magnitude of the \(n^{\text{th}}\) neuron in the \(l^{\text{th}}\) layer. In some experiments we used the \(\log\) of the magnitude.
  • \(dist\) - distance function, typically \(l_1\) or \(l_2\).
  • \(avg\) - arithmetic mean. In some experiments, the median was used instead due to its robustness to outliers.
  • \(L\) - number of layers.
  • \(d_\text{ff}\) - number of neurons in a layer. In some experiments, we only summed over neurons with magnitude below the average magnitude of the layer, to encourage the growth of small neurons.
  • \(stop\_grad\) - stops the gradient from flowing through. Without stopping the gradient, the loss gradient computation becomes intractable.
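The value of the midpoint loss can be computed as in the sketch below, which exposes the variants from the list as keyword arguments. The function and argument names are made up for this example. NumPy has no automatic differentiation, so \(stop\_grad\) appears only as a comment: in an autodiff framework the per-layer target would be detached from the graph. Applying the below-average filter after the optional log transform is also an assumption.

```python
import numpy as np

def midpoint_loss(magnitudes, dist="l1", use_log=False, center="mean", below_only=False):
    """Midpoint loss for an array of neuron magnitudes of shape (L, d_ff)."""
    m = np.log(magnitudes) if use_log else np.asarray(magnitudes, dtype=float)
    if center == "median":
        target = np.median(m, axis=1, keepdims=True)
    else:
        target = np.mean(m, axis=1, keepdims=True)
    # `target` plays the role of stop_grad(avg(M_l)): it is treated as a constant.
    diff = m - target
    if below_only:
        diff = np.where(m < target, diff, 0.0)  # only neurons below the layer average
    if dist == "l1":
        return float(np.sum(np.abs(diff)))
    return float(np.sum(diff ** 2))

# Two layers with two neurons each; only the first layer is uneven.
m = np.array([[1.0, 3.0], [2.0, 2.0]])
loss_l1 = midpoint_loss(m)                       # 2.0
loss_below = midpoint_loss(m, below_only=True)   # 1.0
```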

Since this idea is quite similar to weight decay, we decided not to optimize this term with Adam, but to split it from the task loss and optimize it using simple gradient descent. AdamW uses a similar technique to incorporate the weight decay term.

Midpoint loss achieved the goal of boosting the small neurons, but it failed to make a positive impact on the model’s performance.